import torch
from Conversion import *
from transformers import AutoTokenizer


# Fetch the GPT-2 vocabulary from huggingface.co (cached two directories up).
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir="../..")
# Reuse the end-of-sequence token as the padding token, so both share one id.
tokenizer.pad_token = tokenizer.eos_token
pad_token_id = tokenizer.pad_token_id

# Every encoded input is padded out to this many tokens.
seq_length = 64

# `subtext` is the first sentence of `text`; its token count later determines
# how many leading positions of the sequence are masked.
subtext = "Big Ben is probably the world's most famous clock."
text = "Big Ben is probably the world's most famous clock. Big Ben is probably the world's most famous clock.  "

train_subsequence = tokenizer.encode(subtext)
# Wrapped in an outer list so the ids form a batch holding a single sequence.
input_ids = [
    tokenizer.encode(text, max_length=seq_length, padding="max_length"),
]


# A single sequence is processed per batch.
batch_size = 1

# Build the constructed model, the simulated reference model and their configs
# (Construct_NASgpt is presumably provided by the `Conversion` star import).
(constructed_gpt2, simulated_gpt2, model_config, config) = Construct_NASgpt()

# Switch both models to evaluation mode.
simulated_gpt2.eval()
constructed_gpt2.eval()

# Use whichever device holds the simulated model's parameters.
device = next(simulated_gpt2.parameters()).device
# Alternative input (disabled): random token ids instead of the encoded text.
# input_ids = np.random.randint(model_config.vocab_size, size=(batch_size, seq_length))
input_ids = torch.tensor(input_ids, dtype=torch.int32, device=device)


# Mask of shape (batch_size, seq_length): ones over the first
# len(train_subsequence) positions of every row, zeros everywhere else.
bidirection_mask = torch.zeros(batch_size, seq_length).to(device)
bidirection_mask[:, :len(train_subsequence)] = 1.0

# Labels start as a copy of the input ids; padding positions and the masked
# prefix are then overwritten with -100 (presumably so the loss ignores them;
# -100 is the default ignore_index of torch's cross-entropy loss).
# NOTE: the pad token is the EOS token, so any EOS id in the input is
# overwritten as well.
target = input_ids.clone().detach()
target[(target == pad_token_id) | (bidirection_mask == 1.0)] = -100

# Loss comparison only: these forward passes are never differentiated, so they
# run without building an autograd graph.
with torch.no_grad():
    hidden_state, input_embeddings, wt_stack, original_loss, final_loss = constructed_gpt2.forward(
        input_ids,
        bidirection_mask,
        test_backward_pass=True,
        continue_from_first_forward_pass=False,
        labels=target,
        pad_token=pad_token_id,
    )
    print ("Forward simulation loss", original_loss)
    print ("Forward-backward simulation loss", final_loss)

    simulated_output = simulated_gpt2.forward(input_ids, output_hidden_states=True, labels=target.long())
    loss = simulated_output[0]
    print ("Simulated model loss", loss)


# Optional single SGD step on the simulated model's first `layer` transformer
# blocks; skipped when config.n_simulation_layers is -1.
if config.n_simulation_layers != -1:
    layer = config.n_simulation_layers

    # loss.backward() needs an autograd graph. This branch used to run inside
    # the torch.no_grad() block above, where every tensor is graph-free and
    # backward() raises a RuntimeError. The forward pass is therefore repeated
    # here with gradient tracking explicitly enabled.
    with torch.enable_grad():
        simulated_output = simulated_gpt2.forward(input_ids, output_hidden_states=True, labels=target.long())
        # The last output element holds the per-layer hidden states.
        hidden_state = simulated_output[-1][layer]
        embed_input = simulated_gpt2.transformer.ln_f(hidden_state)

        # Select the parameters of transformer blocks 0 .. layer-1. The trailing
        # dot keeps e.g. "transformer.h.1." from also matching "transformer.h.10.".
        optimizer = torch.optim.SGD(
            params=[
                p
                for n, p in simulated_gpt2.named_parameters()
                if any(f"transformer.h.{layer_id}." in n for layer_id in range(layer))
            ],
            lr=config.inner_lr,
        )

        # Negative sum, over the masked positions only, of the element-wise
        # product between the layer-normed and the raw hidden state.
        loss = -torch.sum(
            torch.sum(embed_input * hidden_state * bidirection_mask.unsqueeze(dim=-1), dim=-1)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        
        
# Disabled follow-up: re-run the simulated model to collect its hidden states.
# simulated_hidden_state = simulated_gpt2.forward(input_ids, output_hidden_states=True)

        